import numpy as np 
import torch 

def unravel_index(index, shape):
    """Convert a flat row-major index into per-dimension indices.

    Args:
        index: flat index into a tensor of ``shape``. May be a Python int or an
            integer tensor (e.g. an element obtained by iterating
            ``torch.arange``); the argument itself is never modified.
        shape: sequence of dimension sizes (e.g. a ``torch.Size``).

    Returns:
        list of per-dimension indices, outermost dimension first, each of the
        same kind as ``index`` (int or tensor). An empty ``shape`` gives ``[]``.
        The flat index is not range-checked: an out-of-range value wraps in the
        outermost dimension.
    """
    unravel_idx = []
    for dim in reversed(shape):
        unravel_idx.append(index % dim)
        # Out-of-place on purpose: ``index //= dim`` is an in-place op on a
        # tensor and would overwrite the caller's tensor.
        index = index // dim
    return list(reversed(unravel_idx))

def select_subtensor(y, indices):
    """Index a copy of ``y`` successively by each entry of ``indices``.

    Equivalent to ``y.clone()[indices[0]][indices[1]]...``. Because indexing is
    applied to a clone, the result does not share storage with ``y``; an empty
    ``indices`` returns the clone itself.
    """
    result = y.clone()
    for idx in indices:
        result = result[idx]
    return result

def gradient_reciprocal_sum(y, x):
    """For each element of ``y``, sum the reciprocals of its nonzero gradients.

    For every element ``y[i]`` the gradient ``d y[i] / d x`` is computed with
    ``create_graph=True`` (so the result remains differentiable), its zero
    entries are discarded, and ``sum(1 / g)`` over the remaining entries is
    stored at position ``i`` of the output. An element whose gradient is all
    zeros gets ``0``.

    Args:
        y: tensor computed from ``x``; any shape, including 0-d.
        x: tensor that ``y`` depends on (passed to ``torch.autograd.grad``).

    Returns:
        tensor with the same shape as ``y``.
    """
    grad_sum = torch.zeros_like(y)
    # np.ndindex yields one tuple of ints per element in row-major order (and a
    # single ``()`` for a 0-d tensor). Tuple indices address exactly one
    # element, unlike a list of ints, which would be advanced indexing.
    for index in np.ndindex(tuple(y.shape)):
        # Index y directly instead of cloning the whole tensor per element.
        grad = torch.autograd.grad(y[index], x, create_graph=True)[0]
        grad_nz = torch.masked_select(grad, grad != 0.0)
        grad_sum[index] = torch.sum(1.0 / grad_nz)
    return grad_sum

def Calculate_fgrad(x,y):
    """Return the negated middle chunk of ``d(sum x[:, :, :, 2]) / dy``.

    ``x`` is indexed as ``x[:, :, :, 2]``, so it must have at least four
    dimensions. The gradient of that slice (with unit upstream gradient, i.e.
    the gradient of its sum) with respect to ``y`` is built with
    ``create_graph=True`` so it stays differentiable. It is then split into
    three chunks along the last dimension (``torch.chunk`` semantics), and the
    second chunk is returned with its sign flipped.
    """
    target = x[:, :, :, 2]
    grad_y = torch.autograd.grad(
        target,
        y,
        grad_outputs=torch.ones_like(target),
        create_graph=True,
    )[0]
    middle = grad_y.chunk(3, dim=-1)[1]
    return middle.neg()

def inverse_combine(f_grad,group):
    y0 = []
    for i in range(group):
        if i == 0:
            y0.append(torch.ones_like(f_grad))
        else:
            y0.append(f_grad)
    dy0_dx = torch.cat(y0, dim=-1)

    y1 = []
    for i in range(group):
        if i == 0:
            y1.append(f_grad)
        elif i == 1:
            y1.append(f_grad ** 2 + 1)
        elif i == 2:
            y1.append(f_grad ** 2)
    dy1_dx = torch.cat(y1, dim=-1)

    y2 = []
    for i in range(group):
        if i == 0:
            y2.append(f_grad ** 2)
        elif i == 1:
            y2.append(f_grad + f_grad ** 3)
        elif i == 2:
            y2.append(1 + f_grad ** 3)
    dy2_dx = torch.cat(y2, dim=-1)    
    dy_dx = dy0_dx + dy1_dx + dy2_dx

    return dy_dx

def Calculate_fgrad_SNN(x,y):
    """Return the first-half gradients of each half of ``x`` with respect to ``y``.

    ``x`` is split into two chunks along its last dimension. For each chunk,
    the gradient of its sum (unit upstream gradient) with respect to ``y`` is
    computed and split into two chunks along the last dimension, and the first
    of those is kept.

    The first backward pass uses ``retain_graph=True`` so the second pass can
    reuse the graph. The second pass does not retain it, and neither pass sets
    ``create_graph``, so the returned tensors are not differentiable.

    Returns:
        tuple ``(dx1_dy1, dx2_dy1)``.
    """
    def _leading_grad(part, **backward_kwargs):
        # Gradient of sum(part) w.r.t. y, keeping only its first chunk.
        full = torch.autograd.grad(
            part, y, grad_outputs=torch.ones_like(part), **backward_kwargs
        )[0]
        return full.chunk(2, dim=-1)[0]

    halves = x.chunk(2, dim=-1)
    dx1_dy1 = _leading_grad(halves[0], retain_graph=True)
    dx2_dy1 = _leading_grad(halves[1])
    return dx1_dy1, dx2_dy1

def recombine_gradient(dy2_dx2,dy2_dx1,dy1_dx1,dy1_dx2,grad_out):
    """Combine four partial-derivative blocks with a split upstream gradient.

    ``grad_out`` is split into two chunks ``g1, g2`` along its last dimension.
    The result is the concatenation along the last dimension of

    * ``dy1_dx1 * g1 + dy2_dx1 * g2`` and
    * ``dy1_dx2 * g1 + dy2_dx2 * g2``.

    All products are elementwise (with torch broadcasting).
    """
    grad_parts = grad_out.chunk(2, dim=-1)
    g1, g2 = grad_parts[0], grad_parts[1]

    first = dy1_dx1 * g1 + dy2_dx1 * g2
    second = dy1_dx2 * g1 + dy2_dx2 * g2

    return torch.cat((first, second), dim=-1)

